{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "cellType": "markdown",
    "uuid": "67fb8ca0-39dd-4ba5-b400-0ddc9d60e387"
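   },
   "source": [
    "## MNIST-DNN: Handwritten Digit Recognition with a Simple DNN\n",
    "\n",
    "This example shows how to use Alibaba's PAI-DSW (Data Science Workspace) to write a simple DNN model that recognizes handwritten digits, the \"Hello World\" project of deep learning.\n",
    "\n",
    "1. First, download the MNIST data, a large collection of 28x28 grayscale images of handwritten digits, each flattened into a 784-element vector. This notebook uses two of the splits: 55,000 training examples and 10,000 test examples."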
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "uuid": "6cf0265c-85d5-4668-943d-fb51a7e97b3b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "(55000, 784) (55000, 10)\n",
      "(10000, 784) (10000, 10)\n"
     ]
    }
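   ],
   "source": [
    "# Note: the tensorflow.examples.tutorials module is only available in TensorFlow 1.x.\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "\n",
    "# Download (if needed) and load MNIST; one_hot=True encodes each label as a 10-element one-hot vector.\n",
    "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)\n",
    "\n",
    "# Each image is flattened to 784 values (28x28); each label has 10 classes.\n",
    "print(mnist.train.images.shape, mnist.train.labels.shape)\n",
    "print(mnist.test.images.shape, mnist.test.labels.shape)"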
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cellType": "markdown",
    "uuid": "c89dd01c-c7c7-4faf-9d9e-ce59567ec4f0"
   },
   "source": [
    "2. Build a simple single-layer DNN (a softmax classifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "uuid": "4a43b63a-b7c7-4028-a74b-be4a1bfd42b2"
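   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# An InteractiveSession installs itself as the default session,\n",
    "# so .run() and .eval() can be called without passing a session.\n",
    "sess = tf.InteractiveSession()\n",
    "\n",
    "# Input: a batch of flattened 28x28 images (None = any batch size).\n",
    "x = tf.placeholder(tf.float32, [None, 784])\n",
    "\n",
    "# Weights and biases of the single layer, initialized to zero.\n",
    "W = tf.Variable(tf.zeros([784, 10]))\n",
    "b = tf.Variable(tf.zeros([10]))\n",
    "\n",
    "# Predicted probability for each of the 10 digit classes.\n",
    "y = tf.nn.softmax(tf.matmul(x, W) + b)"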
   ]
  },
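  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model defined above can be written out as a formula. This is a sketch in our own notation, not part of the original code: $x \\in \\mathbb{R}^{784}$ is one flattened image treated as a row vector, $W \\in \\mathbb{R}^{784 \\times 10}$ and $b \\in \\mathbb{R}^{10}$ are the variables `W` and `b`, and $z$ is a name introduced here for the pre-softmax scores.\n",
    "\n",
    "$$z = xW + b, \\qquad y_j = \\mathrm{softmax}(z)_j = \\frac{e^{z_j}}{\\sum_{k=1}^{10} e^{z_k}}, \\quad j = 1, \\dots, 10$$\n",
    "\n",
    "Each $y_j$ lies in $(0, 1)$ and the ten values sum to 1, so $y$ can be read as a probability distribution over the ten digit classes."
   ]
  },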
  {
   "cell_type": "markdown",
   "metadata": {
    "cellType": "markdown",
    "uuid": "ccd16048-66bd-43c7-8267-b7e844ec9bed"
   },
   "source": [
    "3. Define the loss function (cross-entropy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "uuid": "618c5569-c935-4860-b1d3-78e48db43972"
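   },
   "outputs": [],
   "source": [
    "# Ground-truth labels, one-hot encoded.\n",
    "y_ = tf.placeholder(tf.float32, [None, 10])\n",
    "# Cross-entropy per example (summed over the 10 classes), averaged over the batch.\n",
    "cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), axis=1))"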
   ]
  },
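  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written as a formula, the loss computed by `cross_entropy` looks roughly as follows. The symbols are our own labels, introduced here for illustration: $N$ is the batch size, $y'_{ij}$ is the one-hot label fed through `y_`, and $y_{ij}$ is the predicted probability that example $i$ belongs to class $j$.\n",
    "\n",
    "$$L = -\\frac{1}{N} \\sum_{i=1}^{N} \\sum_{j=1}^{10} y'_{ij} \\, \\log y_{ij}$$\n",
    "\n",
    "Because each label is one-hot, only the term for the true class survives in the inner sum, so $L$ is the average negative log-probability that the model assigns to the correct digit."
   ]
  },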
  {
   "cell_type": "markdown",
   "metadata": {
    "cellType": "markdown",
    "uuid": "95fc4125-ea00-4e4a-8a63-79506ddc5375"
   },
   "source": [
    "4. Choose an optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "uuid": "2c9b5832-6ca7-42d2-a201-d1f20e254075"
   },
   "outputs": [],
   "source": [
    "# Plain gradient descent with a learning rate of 0.5, minimizing the cross-entropy loss.\n",
    "train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)"
   ]
  },
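  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each run of `train_step` applies one gradient descent update. As a sketch in our own notation (the symbols are assumptions introduced here, not names from the code): $L$ is the value of `cross_entropy` on the current batch and $\\eta$ is the learning rate.\n",
    "\n",
    "$$W \\leftarrow W - \\eta \\, \\frac{\\partial L}{\\partial W}, \\qquad b \\leftarrow b - \\eta \\, \\frac{\\partial L}{\\partial b}, \\qquad \\eta = 0.5$$\n",
    "\n",
    "TensorFlow derives these gradients automatically. For a softmax layer followed by a batch-averaged cross-entropy loss they take a simple closed form, where $X$ is the $N \\times 784$ batch of inputs, $Y$ the $N \\times 10$ predictions with rows $y_i$, and $Y'$ the one-hot labels with rows $y'_i$:\n",
    "\n",
    "$$\\frac{\\partial L}{\\partial W} = \\frac{1}{N} X^{\\top} (Y - Y'), \\qquad \\frac{\\partial L}{\\partial b} = \\frac{1}{N} \\sum_{i=1}^{N} (y_i - y'_i)$$"
   ]
  },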
  {
   "cell_type": "markdown",
   "metadata": {
    "cellType": "markdown",
    "uuid": "a66ab3f8-bbc9-40f9-9a5e-be0f7dc8b508"
   },
   "source": [
    "5. Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "uuid": "759432ea-cc44-449d-a1a7-b63e39992d94"
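   },
   "outputs": [],
   "source": [
    "# Initialize W and b in the default (interactive) session.\n",
    "tf.global_variables_initializer().run()\n",
    "\n",
    "# Run 1000 training steps, each on a mini-batch of 100 examples.\n",
    "for i in range(1000):\n",
    "    batch_xs, batch_ys = mnist.train.next_batch(100)\n",
    "    train_step.run({x: batch_xs, y_: batch_ys})\n"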
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cellType": "markdown",
    "uuid": "3479b93a-70e5-49f2-925e-29d9220bb622"
   },
   "source": [
    "6. Evaluate the model and print its accuracy on the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "uuid": "79514804-a9d2-40b4-b91e-136d1eea25e2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9203\n"
     ]
    }
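   ],
   "source": [
    "# A prediction is correct when the most probable class matches the label.\n",
    "correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))\n",
    "\n",
    "# Accuracy is the fraction of correct predictions.\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "\n",
    "# Evaluate on the full test set.\n",
    "print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))"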
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
